# Copyright (c) 2024 Bytedance Ltd. and/or its affiliates
# This file is part of ByteQC.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pyscf import gto
from byteqc.cuobc import scf
from byteqc.embyte.Solver import GPU_CCSDSolver, GPU_MP2Solver
from byteqc.embyte.Localization import iao
from byteqc.embyte.Tools.fragment import Fragment
from byteqc import embyte
import pyscf
import time
import numpy
import os
import cupy
# Module-level side effect (runs on import): passing None disables CuPy's
# pinned (page-locked) host-memory pool, per the CuPy documentation for
# set_pinned_memory_allocator. NOTE(review): presumably done so the pool does
# not keep large pinned host buffers alive between solver steps — confirm.
cupy.cuda.set_pinned_memory_allocator(None)


def get_fragments(orb_list, equi_list):
    """Serialize one ``Fragment`` per entry of ``equi_list``.

    Args:
        orb_list: Per-fragment orbital index collections; ``orb_list[k]`` is
            paired with ``equi_list[k]``. Must have at least
            ``len(equi_list)`` entries.
        equi_list: Per-fragment labels passed as the second ``Fragment``
            argument. This script builds it with
            ``get_atom_frag_list_eq_list_water_cluster``.

    Returns:
        list of dict: ``Fragment(orb_list[k], equi_list[k], ['main']).to_dict()``
        for every ``k``, in order. The meaning of the literal ``['main']``
        third argument is defined by ``Fragment``.
    """
    fragment_dicts = []
    for idx, equi in enumerate(equi_list):
        fragment = Fragment(orb_list[idx], equi, ['main'])
        fragment_dicts.append(fragment.to_dict())
    return fragment_dicts


def get_atom_frag_list_eq_list_water_cluster(mol):
    """Split a water cluster into one ``[O, H, H]`` atom fragment per water.

    Each oxygen atom is paired with the two other atoms nearest to it. Nearness
    is the Euclidean distance on ``mol.atom_coords(unit='A')``, and both
    partners must be hydrogens.

    Args:
        mol: Molecule exposing ``natm``, ``atom_coords(unit=...)`` and
            ``atom_pure_symbol(i)``, e.g. a built ``pyscf.gto.Mole``.

    Returns:
        tuple ``(atom_frag_list, eq_list)``:
            atom_frag_list: list of ``[o_idx, h_idx, h_idx]`` lists of plain
                ``int`` atom indices. There is one per oxygen, in the order the
                oxygens appear in ``mol``, with the nearer hydrogen first.
            eq_list: ``[0, 1, ..., len(atom_frag_list) - 1]``, i.e. one
                distinct label per fragment.

    Raises:
        ValueError: if an oxygen exists but the molecule has fewer than three
            atoms, or if either atom nearest to an oxygen is not a hydrogen.
    """
    atom_coords = numpy.asarray(mol.atom_coords(unit='A'))
    atom_symbols = [mol.atom_pure_symbol(i) for i in range(mol.natm)]

    atom_frag_list = []
    for i, symbol in enumerate(atom_symbols):
        if symbol != 'O':
            continue
        if len(atom_symbols) < 3:
            raise ValueError(
                f'oxygen atom {i} needs two other atoms to form a water '
                f'fragment, but the molecule has only {len(atom_symbols)} atoms')
        dists = numpy.linalg.norm(atom_coords - atom_coords[i], axis=1)
        # Exclude the oxygen itself explicitly instead of relying on it
        # sorting first (distance 0).
        dists[i] = numpy.inf
        nearest = [int(j) for j in dists.argsort()[:2]]
        for h_ind in nearest:
            # Explicit raise rather than `assert`, which `python -O` strips.
            if atom_symbols[h_ind] != 'H':
                raise ValueError(
                    f'atom {h_ind} ({atom_symbols[h_ind]}) is one of the two '
                    f'atoms nearest to oxygen atom {i}; expected H')
        atom_frag_list.append([i] + nearest)

    # Derive the labels from the fragments actually found so that both lists
    # always have the same length. mol.natm // 3 only agrees with that for a
    # cluster made purely of water molecules.
    eq_list = list(range(len(atom_frag_list)))
    return atom_frag_list, eq_list


def water_dimmer_mol(basis_set):
    """Build the water-dimer test molecule.

    The geometry holds two water molecules as six atoms in O, H, H, O, H, H
    order, written in PySCF's atom-string format. No unit is set, so PySCF's
    default length unit (Angstrom) applies.

    Args:
        basis_set: Basis-set name for the molecule, e.g. ``'aug-cc-pVTZ'``.

    Returns:
        The built ``pyscf.gto.Mole``.
    """
    geometry = '''
        O	0.00006	1.52180 0.00000
        H	-0.09847 0.55394 0.00000
        H	-0.90617 1.85309 0.00000
        O	0.00006	-1.39463 0.00000
        H	0.50188	-1.71221 0.76271
        H	0.50188	-1.71221 -0.76271
    '''
    molecule = gto.M()
    # Mole.build(atom=..., basis=...) assigns both attributes, then builds.
    molecule.build(atom=geometry, basis=basis_set)
    return molecule


if __name__ == '__main__':
    # Example driver: SIE calculation on the water dimer built by
    # water_dimmer_mol, using a GPU MP2 or CCSD solver with IAO localization.
    # Reference outputs for each solver / energy-evaluation setting are
    # listed at the end of this block.

    print('++++++++++++++++++++++++++++====================++++++++++++++++++++++++++++====================')
    print(pyscf.__version__)
    print(pyscf.__file__)

    # SIE threshold(s), written as base-10 exponents: 6.0 -> 1e-6.
    threshold_exponents = [6.0]
    threshold = [10 ** -th for th in threshold_exponents]

    # True selects GPU_MP2Solver, False selects GPU_CCSDSolver.
    if_MP2 = True

    mol_basis = 'aug-cc-pVTZ'

    mol = water_dimmer_mol(mol_basis)
    mf = scf.RHF(mol)

    # All outputs go to a 'result/' directory next to this script.
    logdir = os.path.join(
        os.path.dirname(os.path.abspath(__file__)), 'result/')
    # exist_ok=True avoids the race between an exists() check and makedirs().
    os.makedirs(logdir, exist_ok=True)

    # HF_chkfile (written by RHF.kernel) and JK_file (the output of
    # RHF.get_veff) are intended to come from an RHF calculation without
    # density fitting.
    chkfile = os.path.join(logdir, 'HF_chkfile.chk')
    jk_file = os.path.join(logdir, 'JK_file.npy')
    eri_path = None

    mf.chkfile = chkfile
    mf.kernel()
    # cupy.asnumpy accepts host and device arrays alike, so the save works
    # whether get_veff returns a NumPy or a CuPy array.
    veff = cupy.asnumpy(mf.get_veff())
    numpy.save(jk_file, veff)

    # Explicit check instead of `assert`, which is stripped under `python -O`.
    for required in (chkfile, jk_file):
        if not os.path.exists(required):
            raise FileNotFoundError(required)

    logfile = os.path.join(logdir, 'SIE_result')

    tot_t = time.time()

    SIE_class = embyte.Framework.SIE.SIE_kernel(logfile, chkfile)

    SIE_class.electronic_structure_solver = (
        GPU_MP2Solver if if_MP2 else GPU_CCSDSolver)

    SIE_class.electron_localization_method = iao

    # RDM / in_situ_T select between the energy evaluations whose reference
    # outputs are listed below.
    SIE_class.RDM = False
    SIE_class.in_situ_T = False

    SIE_class.aux_basis = f'{mol.basis}-ri'

    SIE_class.jk_file = jk_file
    SIE_class.eri = eri_path

    atom_list_frag, equi_list = get_atom_frag_list_eq_list_water_cluster(mol)
    SIE_class.threshold = threshold

    orb_list = embyte.Tools.fragment.from_atom_to_orb_iao(mol, atom_list_frag)

    fragments = get_fragments(orb_list, equi_list)

    SIE_class.simulate(mol, chkfile, fragments)

    # Report the wall time measured from tot_t above.
    print(f'Total SIE wall time: {time.time() - tot_t:.2f} s')

    # Running MP2, computing the energy from CI coefficients
    # if_MP2 = True, SIE_class.RDM = False
    # -------------------------------------------------
    # Correlation energy from CI-coefficients: [-0.5716659593618497]
    # -------------------------------------------------
    # Total energy: [-152.6969751057922]

    # Running MP2 with global 1-RDM and in-cluster 2-RDM
    # if_MP2 = True, SIE_class.RDM = True
    # -------------------------------------------------
    # Correlation energy from 1-RDM: [0.5716545316301486]
    # Correlation energy from 2-RDM: [-1.1433319187236994]
    # Correlation energy from RDM: [-0.5716773870935509]
    # Correlation energy from CI-coefficients: [-0.5716659593618497]
    # -------------------------------------------------
    # Total RDM energy: [-152.69698653352393]

    # Running CCSD, computing the energy from CI coefficients
    # if_MP2 = False, SIE_class.RDM = False, SIE_class.in_situ_T = False
    # -------------------------------------------------
    # Correlation energy from CI-coefficients: [-0.580732854160309]
    # -------------------------------------------------
    # Total energy: [-152.70604200059068]

    # Running CCSD with global 1-RDM and in-cluster 2-RDM
    # if_MP2 = False, SIE_class.RDM = True, SIE_class.in_situ_T = False
    # -------------------------------------------------
    # Correlation energy from 1-RDM: [0.593443000351912]
    # Correlation energy from 2-RDM: [-1.1741946365773588]
    # Correlation energy from RDM: [-0.5807516362254468]
    # Correlation energy from CI-coefficients: [-0.580732854160309]
    # -------------------------------------------------
    # Total RDM energy: [-152.7060607826558]

    # Running CCSD(T), computing the energy from CI coefficients
    # if_MP2 = False, SIE_class.RDM = False, SIE_class.in_situ_T = True
    # -------------------------------------------------
    # Correlation energy from CI-coefficients: [-0.580732854160309]
    # ====================================
    # CCSD(T) correction is [-0.01816449]
    # Correlation energy from CI-coefficients with (T) correction: [-0.5988973479514096]
    # ====================================
    # -------------------------------------------------
    # Total energy: [-152.72420649438178]

    # Running CCSD(T) with global 1-RDM and in-cluster 2-RDM
    # if_MP2 = False, SIE_class.RDM = True, SIE_class.in_situ_T = True
    # -------------------------------------------------
    # Correlation energy from 1-RDM: [0.593443000351912]
    # Correlation energy from 2-RDM: [-1.1741946365773588]
    # Correlation energy from RDM: [-0.5807516362254468]
    # Correlation energy from CI-coefficients: [-0.580732854160309]
    # ====================================
    # CCSD(T) correction is [-0.01816449]
    # Correlation energy from RDM with (T) correction: [-0.5989161300165474]
    # Correlation energy from CI-coefficients with (T) correction: [-0.5988973479514096]
    # ====================================
    # -------------------------------------------------
    # Total RDM energy: [-152.7242252764469]
